{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.Tensor.register_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "y.grad:  None\n",
      "y_grad[0]:  tensor([0.2500, 0.2500, 0.2500, 0.2500])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "e:\\Anaconda_data\\envs\\pt_1.11\\lib\\site-packages\\torch\\_tensor.py:1104: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\build\\aten\\src\\ATen/core/TensorBody.h:475.)\n",
      "  return self._grad\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# Gradients captured by the hook are collected in this list\n",
    "y_grad = []\n",
    "\n",
    "def grad_hook(grad):\n",
    "    \"\"\"Store the gradient handed to the hook in the global ``y_grad`` list.\"\"\"\n",
    "    y_grad.append(grad)\n",
    "\n",
    "x = torch.tensor([2., 2., 2., 2.], requires_grad=True)\n",
    "y = x.pow(2)                     # intermediate (non-leaf) tensor\n",
    "h = y.register_hook(grad_hook)   # keep the handle so the hook can be removed later\n",
    "z = y.mean()\n",
    "z.backward()\n",
    "\n",
    "# y.grad prints None for the non-leaf tensor, while the hook has captured its gradient\n",
    "print(\"y.grad: \", y.grad)\n",
    "print(\"y_grad[0]: \", y_grad[0])\n",
    "h.remove()    # removes the hook"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，z.backward()结束后张量y的grad为None：y是非叶子节点张量，其梯度在反向传播结束后不会保留在.grad中（除非事先调用y.retain_grad()）。  \n",
    "而在注册到张量y上的hook函数（grad_hook）中，y的梯度被保存到了y_grad列表里，因此z.backward()结束后，仍然可以从y_grad[0]中读到y的梯度tensor([0.2500, 0.2500, 0.2500, 0.2500])。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([2., 2., 2., 2.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "def grad_hook(grad):\n",
    "    \"\"\"Return a doubled copy of the incoming gradient.\n",
    "\n",
    "    A tensor hook should not modify its argument in place; the tensor it\n",
    "    returns is used by autograd in place of ``grad``.\n",
    "    \"\"\"\n",
    "    return grad * 2\n",
    "\n",
    "x = torch.tensor([2., 2., 2., 2.], requires_grad=True)\n",
    "y = torch.pow(x, 2)\n",
    "z = torch.mean(y)\n",
    "h = x.register_hook(grad_hook)   # hook on the leaf tensor: rewrites what lands in x.grad\n",
    "z.backward()\n",
    "print(x.grad)\n",
    "h.remove()    # removes the hook\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "原x的梯度为tensor([1., 1., 1., 1.])，经grad_hook操作后，梯度为tensor([2., 2., 2., 2.])。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## torch.nn.Module.register_forward_hook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "code_folding": [
     2
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "output shape: torch.Size([1, 2, 1, 1])\n",
      "output value: tensor([[[[ 9.]],\n",
      "\n",
      "         [[18.]]]], grad_fn=<MaxPool2DWithIndicesBackward0>)\n",
      "\n",
      "feature maps shape: torch.Size([1, 2, 2, 2])\n",
      "output value: tensor([[[[ 9.,  9.],\n",
      "          [ 9.,  9.]],\n",
      "\n",
      "         [[18., 18.],\n",
      "          [18., 18.]]]], grad_fn=<ThnnConv2DBackward0>)\n",
      "\n",
      "input shape: torch.Size([1, 1, 4, 4])\n",
      "input value: (tensor([[[[1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1.],\n",
      "          [1., 1., 1., 1.]]]]),)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Minimal network: a 3x3 convolution (1 -> 2 channels) followed by 2x2 max pooling.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3)\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply ``conv1`` and then ``pool1`` to ``x`` and return the result.\"\"\"\n",
    "        return self.pool1(self.conv1(x))\n",
    "def farward_hook(module, data_input, data_output):\n",
    "    \"\"\"Forward hook that records the hooked module's input and output.\n",
    "\n",
    "    Appends ``data_input`` to the global ``input_block`` list and\n",
    "    ``data_output`` to the global ``fmap_block`` list; ``module`` is unused.\n",
    "    \"\"\"\n",
    "    input_block.append(data_input)\n",
    "    fmap_block.append(data_output)\n",
    "    \n",
    "# Build the network and give conv1 deterministic parameters:\n",
    "# kernel 0 is all ones, kernel 1 is all twos, bias is zero.\n",
    "net = Net()\n",
    "with torch.no_grad():   # in-place parameter initialisation must not be tracked by autograd\n",
    "    net.conv1.weight[0].fill_(1)\n",
    "    net.conv1.weight[1].fill_(2)\n",
    "    net.conv1.bias.zero_()\n",
    "\n",
    "# Register the hook; the lists below receive what the hook records\n",
    "fmap_block = list()\n",
    "input_block = list()\n",
    "h = net.conv1.register_forward_hook(farward_hook)\n",
    "\n",
    "# inference\n",
    "fake_img = torch.ones((1, 1, 4, 4))   # batch size * channel * H * W\n",
    "output = net(fake_img)\n",
    "h.remove()    # removes the hook so later forward passes do not keep appending\n",
    "\n",
    "# Inspect the network output, the recorded conv1 feature map and the recorded conv1 input\n",
    "print(\"output shape: {}\\noutput value: {}\\n\".format(output.shape, output))\n",
    "print(\"feature maps shape: {}\\noutput value: {}\\n\".format(fmap_block[0].shape, fmap_block[0]))\n",
    "print(\"input shape: {}\\ninput value: {}\".format(input_block[0][0].shape, input_block[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先初始化一个网络，卷积层有两个卷积核，权值分别为全1和全2，bias设置为0，池化层采用2*2的最大池化。  \n",
    "\n",
    "在进行forward之前对module——conv1注册了forward_hook函数，然后执行前向传播（output=net(fake_img)），当前向传播完成后，fmap_block列表中的第一个元素就是conv1层输出的特征图了。  \n",
    "\n",
    "这里注意观察farward_hook函数有data_input和data_output两个变量：特征图是data_output这个变量，而data_input是conv1层的输入数据，conv1层的输入是一个tuple的形式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pt_1.11",
   "language": "python",
   "name": "pt_1.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
